{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: generating and running code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Warning: This notebook runs LLM-generated code without any checks. Run at your own risk."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loading a code model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from guidance import models, gen\n",
    "from guidance.library._gen import will_gen\n",
    "from guidance import capture, one_or_more, any_char, zero_or_more, commit_point, select\n",
    "import guidance\n",
    "import re\n",
    "base_path = '/home/marcotcr_google_com/work/models/'\n",
    "model_path = base_path + 'mistral-7b-codealpaca-lora.Q8_0.gguf'\n",
    "mistral = models.LlamaCpp(model_path, n_gpu_layers=-1, n_ctx=4096)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loading the HumanEval dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "dataset = load_dataset(\"openai_humaneval\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's write a very simple baseline:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import guidance\n",
    "@guidance\n",
    "def baseline(lm, prompt):\n",
    "    r = re.findall(r'def (.*?)\\(', prompt)\n",
    "    name = r[-1]\n",
    "    lm += f'Here is an implementation of {name}:\\n'\n",
    "    lm += '```python\\n' + prompt + gen(max_tokens=800, stop=['```', 'if __name__', 'def test'], name='program')\n",
    "    lm = lm.set('program', prompt + lm['program'])\n",
    "    return lm   "
   ]
  },
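  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the regular expression in `baseline` extracts, here is a small standalone check on a made-up prompt. The prompt text and the helper name `last_function_name` are illustrative assumptions, not part of HumanEval or guidance. When a prompt defines a helper before the target function, `r[-1]` picks the last `def`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def last_function_name(prompt):\n",
    "    # Same pattern as in `baseline`: lazily capture the text between 'def ' and '('\n",
    "    names = re.findall(r'def (.*?)\\(', prompt)\n",
    "    return names[-1]\n",
    "\n",
    "# Hypothetical prompt with a helper defined before the target function\n",
    "toy_prompt = '''\n",
    "def helper(x):\n",
    "    return x + 1\n",
    "\n",
    "def target(lst):\n",
    "    \"\"\"Return something computed from lst.\"\"\"\n",
    "'''\n",
    "print(last_function_name(toy_prompt))  # target"
   ]
  },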
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style='margin: 0px; padding: 0px; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>Here is an implementation of solution:\n",
       "```python\n",
       "\n",
       "def solution(lst):\n",
       "    &quot;&quot;&quot;Given a non-empty list of integers, return the sum of all of the odd elements that are in even positions.\n",
       "    \n",
       "\n",
       "    Examples\n",
       "    solution([5, 8, 7, 1]) ==&gt; 12\n",
       "    solution([3, 3, 3, 3, 3]) ==&gt; 9\n",
       "    solution([30, 13, 24, 321]) ==&gt;0\n",
       "    &quot;&quot;&quot;<span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>   </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> return</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> sum</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>lst</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>[</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>i]</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> for</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>i in</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> range</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>0</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> len</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>lst</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>),</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 
3px;' title='0.0'>2</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>if l</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>st</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>[</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>i]</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> %</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>2</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> !=</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>0</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>#</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> test</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> the</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> function</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>print</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>s</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>olution</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>([</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>5</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>8</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>7</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>]))</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> #</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> shoul</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' 
title='0.0'>d print</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>2</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>print</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>s</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>olution</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>([</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>]))</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' 
title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> #</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> shoul</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>d print</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>9</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>print</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>s</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>olution</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>([</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>0</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>2</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>4</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>2</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' 
title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>]))</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> #</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> shoul</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>d print</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>0</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span></pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "idx = 121\n",
    "prompt = dataset['test']['prompt'][idx]\n",
    "lm = mistral + baseline(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a simple function to evaluate a generated program with the HumanEval evaluation tests:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Returns True if it passes the evaluation tests, False otherwise\n",
    "def eval_program(program, i):\n",
    "    # Loads the `check` function\n",
    "    exec(dataset['test']['test'][i])\n",
    "    try:\n",
    "        # Executes the function definition\n",
    "        exec(program, globals())\n",
    "    except Exception as e:\n",
    "        # Program not valid\n",
    "        return False\n",
    "    name = dataset['test']['entry_point'][i]\n",
    "    try:\n",
    "        # Run the unit tests\n",
    "        eval('check(%s)' % name)\n",
    "        # If we get here, we passed the tests\n",
    "        return True\n",
    "    except:\n",
    "        # The program ran but failed the unit tests, or raised some other exception\n",
    "        return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12\n",
      "9\n",
      "0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_program(lm['program'], idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you run this prompt on all of HumanEval, you get 54.9\\% accuracy.  \n",
    "The model generates valid code (i.e., code that the Python interpreter accepts) on 96\\% of examples, but the code executes without exceptions on only 93\\% of examples."
   ]
  },
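  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The figures above distinguish several failure stages. A minimal sketch of how those stages could be told apart is shown below. It assumes the test source defines `check(candidate)`, as the HumanEval tests used by `eval_program` do. The name `classify_program`, its labels, and the toy programs are illustrative assumptions, not the code used to obtain the figures above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify_program(program, test_src, entry_point):\n",
    "    # A fresh namespace keeps names out of the notebook's globals; it is not a sandbox\n",
    "    ns = {}\n",
    "    try:\n",
    "        exec(program, ns)\n",
    "    except Exception:\n",
    "        return 'invalid'        # the definition itself does not execute\n",
    "    exec(test_src, ns)          # assumed to define `check(candidate)`\n",
    "    try:\n",
    "        ns['check'](ns[entry_point])\n",
    "    except AssertionError:\n",
    "        return 'wrong_answer'   # ran, but an assert in the tests failed\n",
    "    except Exception:\n",
    "        return 'runtime_error'  # raised something other than a failed assert\n",
    "    return 'pass'\n",
    "\n",
    "# Toy test and toy programs, made up for illustration\n",
    "toy_test = 'def check(candidate):\\n    assert candidate(2) == 4\\n'\n",
    "print(classify_program('def double(x):\\n    return 2 * x\\n', toy_test, 'double'))  # pass\n",
    "print(classify_program('def double(x):\\n    return x\\n', toy_test, 'double'))      # wrong_answer\n",
    "print(classify_program('def double(x):\\n    return x.foo\\n', toy_test, 'double'))  # runtime_error\n",
    "print(classify_program('def double(x:\\n', toy_test, 'double'))                     # invalid"
   ]
  },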
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try another one:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style='margin: 0px; padding: 0px; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>Here is an implementation of triangle_area:\n",
       "```python\n",
       "\n",
       "def triangle_area(a, b, c):\n",
       "    &#x27;&#x27;&#x27;\n",
       "    Given the lengths of the three sides of a triangle. Return the area of\n",
       "    the triangle rounded to 2 decimal points if the three sides form a valid triangle. \n",
       "    Otherwise return -1\n",
       "    Three sides make a valid triangle when the sum of any two sides is greater \n",
       "    than the third side.\n",
       "    Example:\n",
       "    triangle_area(3, 4, 5) == 6.00\n",
       "    triangle_area(1, 2, 10) == -1\n",
       "    &#x27;&#x27;&#x27;<span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>   </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>if a</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> +</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> b</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> &gt;</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> c</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> an</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>d a</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> +</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> c</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> &gt;</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> b</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> an</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>d b</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> +</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> c</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> &gt;</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> a</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); 
border-radius: 3px;' title='0.0'>:</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>       </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> s</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> =</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> (</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>a</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> +</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> b</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> +</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> c</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> /</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>2</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>       </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> return</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> roun</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>d(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>s</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> *</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> (</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>s</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> -</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> a</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> *</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> (</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>s</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> -</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> b</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> *</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> (</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 
3px;' title='0.0'>s</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> -</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> c</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>),</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>2</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>   </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> else</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>:</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>       </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> return</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> -</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span></pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "idx = 71\n",
    "prompt = dataset['test']['prompt'][idx]\n",
    "lm = mistral + baseline(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_program(lm['program'], idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that this time the generated program doesn't pass the evaluation tests.  \n",
    "But it's worse than that: the program doesn't even pass the first example in the docstring:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "36.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "exec(lm['program'])\n",
    "triangle_area(3, 4, 5) # should be 6.00"
   ]
  },
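  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The value 36.0 is what Heron's formula gives before the square root is taken: with `s = (3 + 4 + 5) / 2 = 6`, the product `s * (s - a) * (s - b) * (s - c)` is `6 * 3 * 2 * 1 = 36`, whose square root is 6. The generated program appears to have dropped the square root. A small standalone check (the name `heron_area` is an illustrative assumption, not part of the notebook or of HumanEval):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def heron_area(a, b, c):\n",
    "    # Heron's formula: area = sqrt(s * (s - a) * (s - b) * (s - c)), s being the semi-perimeter\n",
    "    s = (a + b + c) / 2\n",
    "    return round(math.sqrt(s * (s - a) * (s - b) * (s - c)), 2)\n",
    "\n",
    "print(heron_area(3, 4, 5))  # 6.0"
   ]
  },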
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This suggests an improvement: let's extract the examples in the docstring as tests, and only return a program if it passes at least those tests.  \n",
    "First, let's write a simple prompt that turns the docstring examples into tests:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from guidance import any_char_but, regex\n",
    "@guidance(stateless=True)\n",
    "def test(lm, fn_name):\n",
    "    \"\"\"Only allows assert fn_name(args) == expected\"\"\"\n",
    "    return lm + '   assert ' + fn_name + '(' + capture(zero_or_more(any_char_but(['\\n'])), name='args') + commit_point(select([') == ', ') is ', ')' + regex(r'\\s\\s?\\s?\\s?') + '== '])) + capture(one_or_more(any_char()), name='result') + commit_point('\\n')\n",
    "\n",
    "@guidance\n",
    "def write_tests(lm, prompt):\n",
    "    r = re.findall(r'def (.*?)\\(', prompt)\n",
    "    name = r[-1]\n",
    "    lm += '```python\\n' + prompt + '    pass\\n'\n",
    "    lm += f'\\ndef test_{name}():\\n'\n",
    "    lm += '    \"\"\"Turns the example(s) in the docstring above into asserts\"\"\"\\n'\n",
    "    args = []\n",
    "    expected = []\n",
    "    # Write at most 10 tests, but stop when the model wants to stop\n",
    "    for i in range(10):\n",
    "        lm += test(name)\n",
    "        args.append(lm['args'])\n",
    "        expected.append(lm['result'])\n",
    "        if not lm.will_gen('assert', ignore_spaces=True):\n",
    "            break\n",
    "    lm = lm.set('args', args)\n",
    "    lm = lm.set('expected', expected)\n",
    "    return lm"
   ]
  },
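  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once `args` and `expected` have been captured, each pair can be turned back into a comparison and run against a candidate program. The following is only a sketch of that idea under stated assumptions: it compares with `==` only (the grammar above also allows `is`), and the name `passes_docstring_tests` and the toy inputs are made up for illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def passes_docstring_tests(program, fn_name, args, expected):\n",
    "    # Execute the candidate in its own namespace, then replay each captured example\n",
    "    ns = {}\n",
    "    try:\n",
    "        exec(program, ns)\n",
    "        for a, e in zip(args, expected):\n",
    "            if not eval(f'{fn_name}({a}) == {e}', ns):\n",
    "                return False\n",
    "    except Exception:\n",
    "        # Invalid program, or an exception while running an example\n",
    "        return False\n",
    "    return True\n",
    "\n",
    "# Toy program and examples, made up for illustration\n",
    "toy_program = 'def add(a, b):\\n    return a + b\\n'\n",
    "print(passes_docstring_tests(toy_program, 'add', ['1, 2', '0, 0'], ['3', '0']))  # True\n",
    "print(passes_docstring_tests(toy_program, 'add', ['1, 2'], ['4']))               # False"
   ]
  },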
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style='margin: 0px; padding: 0px; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>```python\n",
       "\n",
       "def triangle_area(a, b, c):\n",
       "    &#x27;&#x27;&#x27;\n",
       "    Given the lengths of the three sides of a triangle. Return the area of\n",
       "    the triangle rounded to 2 decimal points if the three sides form a valid triangle. \n",
       "    Otherwise return -1\n",
       "    Three sides make a valid triangle when the sum of any two sides is greater \n",
       "    than the third side.\n",
       "    Example:\n",
       "    triangle_area(3, 4, 5) == 6.00\n",
       "    triangle_area(1, 2, 10) == -1\n",
       "    &#x27;&#x27;&#x27;\n",
       "    pass\n",
       "\n",
       "def test_triangle_area():\n",
       "    &quot;&quot;&quot;Turns the example(s) in the docstring above into asserts&quot;&quot;&quot;\n",
       "   assert triangle_area<span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>4</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>5</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> ==</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>6</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>.</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>0</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>0</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span>   assert triangle_area<span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>2</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>0</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> ==</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> -</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span>   assert triangle_area<span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>4</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>7</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> ==</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> -</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span>   assert triangle_area<span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>3</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> ==</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>0</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>.</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>0</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>0</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span>   assert triangle_area<span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>(</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>,</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> </span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>2</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>)</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> ==</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'> -</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>1</span><span style='background-color: rgba(165.0, 0.0, 0, 0.15); border-radius: 3px;' title='0.0'>\n",
       "</span></pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lm = mistral + write_tests(prompt)\n",
    "args = lm['args']\n",
    "expected = lm['expected']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The LM went beyond extracting tests, it also generated a few of its own. While some of these may be incorrect, at least we have the original ones as well.  \n",
    "What's more, we already stored the inputs and expected results in the lm object:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('3, 4, 5', '6.00'),\n",
       " ('1, 2, 10', '-1'),\n",
       " ('3, 4, 7', '-1'),\n",
       " ('3, 3, 3', '0.00'),\n",
       " ('1, 1, 2', '-1')]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (input, expected output)\n",
    "list(zip(lm['args'], lm['expected']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's combine the baseline and the test generation prompts into a single guidance function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "@guidance\n",
    "def reconstruct_tests(lm, name, args, expected):\n",
    "    \"\"\"Helper to format tests nicely\"\"\"\n",
    "    lm += f'def test_{name}():\\n'\n",
    "    for arg, e in zip(args, expected):\n",
    "        lm += f'   assert {name}({arg}) == {e}\\n'\n",
    "    return lm\n",
    "\n",
    "@guidance\n",
    "def add_program_and_tests(lm, name, program, args, expected):\n",
    "    \"\"\"Helper to format program and tests nicely\"\"\"\n",
    "    lm += f'Here is an implementation of {name}:\\n'\n",
    "    lm += '```python\\n' \n",
    "    lm += program + '\\n'\n",
    "    lm += reconstruct_tests(name, args, expected) + '```\\n'\n",
    "    return lm\n",
    "\n",
    "@guidance\n",
    "def baseline_and_tests(lm, prompt):\n",
    "    lm2 = lm + baseline(prompt)\n",
    "    r = re.findall('def (.*?)\\(', prompt)\n",
    "    name = r[-1]\n",
    "    program = lm2['program']\n",
    "    lm2 = lm + write_tests(prompt)\n",
    "    args, expected = lm2['args'], lm2['expected']\n",
    "    lm = lm.set('program', program)\n",
    "    lm = lm.set('args', args)\n",
    "    lm = lm.set('expected', expected)\n",
    "    lm = lm.set('name', name)\n",
    "    lm += add_program_and_tests(name, program, args, expected)\n",
    "    return lm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style='margin: 0px; padding: 0px; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>Here is an implementation of triangle_area:\n",
       "```python\n",
       "\n",
       "def triangle_area(a, b, c):\n",
       "    &#x27;&#x27;&#x27;\n",
       "    Given the lengths of the three sides of a triangle. Return the area of\n",
       "    the triangle rounded to 2 decimal points if the three sides form a valid triangle. \n",
       "    Otherwise return -1\n",
       "    Three sides make a valid triangle when the sum of any two sides is greater \n",
       "    than the third side.\n",
       "    Example:\n",
       "    triangle_area(3, 4, 5) == 6.00\n",
       "    triangle_area(1, 2, 10) == -1\n",
       "    &#x27;&#x27;&#x27;\n",
       "    if a + b &gt; c and a + c &gt; b and b + c &gt; a:\n",
       "        s = (a + b + c) / 2\n",
       "        return round(s * (s - a) * (s - b) * (s - c), 2)\n",
       "    else:\n",
       "        return -1\n",
       "\n",
       "def test_triangle_area():\n",
       "   assert triangle_area(3, 4, 5) == 6.00\n",
       "   assert triangle_area(1, 2, 10) == -1\n",
       "   assert triangle_area(3, 4, 7) == -1\n",
       "   assert triangle_area(3, 3, 3) == 0.00\n",
       "   assert triangle_area(1, 1, 2) == -1\n",
       "```\n",
       "</pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lm = mistral + baseline_and_tests(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(['3, 4, 5', '1, 2, 10', '3, 4, 7', '3, 3, 3', '1, 1, 2'],\n",
       " ['6.00', '-1', '-1', '0.00', '-1'])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lm['args'], lm['expected']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, if we have a generated program and a set of tests, we can write a guidance function that runs the tests and outputs the results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function to load the program\n",
    "def load_program(name, program):\n",
    "    error = None\n",
    "    try:\n",
    "        exec(program, globals())\n",
    "        fn = eval(name)\n",
    "    except Exception as e:\n",
    "        fn = None\n",
    "        error = e\n",
    "    return fn, error\n",
    "\n",
    "# Tolerance when x and y are floats\n",
    "def equals(x, y):\n",
    "    if isinstance(x, float) and isinstance(y, float):\n",
    "        return abs(x - y) < 0.00001\n",
    "    else:\n",
    "        return x == y\n",
    "\n",
    "@guidance\n",
    "def run_tests(lm, name, program, args, expected):\n",
    "    fn, error = load_program(name, program)\n",
    "    all_pass = True\n",
    "    lm += 'Running the test(s) above gives:\\n'\n",
    "    for arg, e in zip(args, expected):\n",
    "        # Reconstruct the test\n",
    "        lm += f'assert {name}({arg}) == {e}\\n'\n",
    "        try:\n",
    "            arg = eval(arg)\n",
    "            expected_result = eval(e)\n",
    "        except:\n",
    "            continue\n",
    "        try:\n",
    "            if isinstance(arg, tuple):\n",
    "                r = fn(*arg)\n",
    "            else:\n",
    "                r = fn(arg)\n",
    "        except Exception as ex:\n",
    "            r = ex\n",
    "        if equals(r, expected_result):\n",
    "            lm += 'Assertion passed.\\n'\n",
    "        else:\n",
    "            all_pass = False\n",
    "            lm += f'Assertion failed.\\n'\n",
    "            lm += f'Expected: {e}\\n'\n",
    "            lm += f'Actual: {r}\\n'\n",
    "        lm += '---\\n'\n",
    "    lm = lm.set('all_pass', all_pass)\n",
    "    return lm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style='margin: 0px; padding: 0px; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>Running the test(s) above gives:\n",
       "assert triangle_area(3, 4, 5) == 6.00\n",
       "Assertion failed.\n",
       "Expected: 6.00\n",
       "Actual: 36.0\n",
       "---\n",
       "assert triangle_area(1, 2, 10) == -1\n",
       "Assertion passed.\n",
       "---\n",
       "assert triangle_area(3, 4, 7) == -1\n",
       "Assertion passed.\n",
       "---\n",
       "assert triangle_area(3, 3, 3) == 0.00\n",
       "Assertion failed.\n",
       "Expected: 0.00\n",
       "Actual: 15.19\n",
       "---\n",
       "assert triangle_area(1, 1, 2) == -1\n",
       "Assertion passed.\n",
       "---\n",
       "</pre>"
      ],
      "text/plain": [
       "<guidance.models._llama_cpp.LlamaCpp at 0x7fb18eae7a30>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mistral + run_tests(lm['name'], lm['program'], lm['args'], lm['expected'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can put this all together into a function that gets the LM to rewrite the program when the tests don't work:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "@guidance\n",
    "def run_tests_and_fix(lm, prompt):\n",
    "    lm2 = lm + baseline_and_tests(prompt)\n",
    "    name, program, args, expected = lm2['name'], lm2['program'], lm2['args'], lm2['expected']\n",
    "    i = 0\n",
    "    # Try this at most 3 times\n",
    "    while i != 3:\n",
    "        i += 1\n",
    "        lm2 += run_tests(name, program, args, expected)\n",
    "        # Passing the tests, I can stop.\n",
    "        if lm2['all_pass']:\n",
    "            break\n",
    "        lm2 += f'\\n'\n",
    "        # Get the model to think about what's wrong\n",
    "        lm2 += f'My implementation of {name} is wrong, because''' + gen(stop='\\n') + '\\n'\n",
    "        lm2 += f'In order to fix it, I need to''' + gen(stop='\\n') + '\\n'\n",
    "        lm2 += f'Here is a fixed implementation:\\n'\n",
    "        # Write a new program\n",
    "        lm2 += '```python\\n' + prompt + gen(max_tokens=800, stop_regex='\\n[^\\s]', name='program')\n",
    "        lm2 += '```\\n'\n",
    "        # Reset the slate, start over with new program\n",
    "        program = prompt + lm2['program']\n",
    "        lm2 = lm + add_program_and_tests(name, program, args, expected)\n",
    "        lm2 = lm2.set('program', program)\n",
    "        lm + 'ae' + gen(max_tokens=10)\n",
    "    return lm2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style='margin: 0px; padding: 0px; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'>Here is an implementation of triangle_area:\n",
       "```python\n",
       "\n",
       "def triangle_area(a, b, c):\n",
       "    &#x27;&#x27;&#x27;\n",
       "    Given the lengths of the three sides of a triangle. Return the area of\n",
       "    the triangle rounded to 2 decimal points if the three sides form a valid triangle. \n",
       "    Otherwise return -1\n",
       "    Three sides make a valid triangle when the sum of any two sides is greater \n",
       "    than the third side.\n",
       "    Example:\n",
       "    triangle_area(3, 4, 5) == 6.00\n",
       "    triangle_area(1, 2, 10) == -1\n",
       "    &#x27;&#x27;&#x27;\n",
       "    # Check if the sides form a valid triangle\n",
       "    if a + b &gt; c and a + c &gt; b and b + c &gt; a:\n",
       "        # Calculate the semi-perimeter\n",
       "        s = (a + b + c) / 2\n",
       "        # Check if the triangle is right-angled\n",
       "        if a**2 + b**2 == c**2:\n",
       "            # Use Gauss&#x27;s formula for the area of a right-angled triangle\n",
       "            area = ((s * (s - a) * (s - b) * (s - c)) ** 0.5)\n",
       "        else:\n",
       "            # Use Heron&#x27;s formula for the area of a general triangle\n",
       "            area = ((s * (s - a) * (s - b) * (s - c)) ** 0.5)\n",
       "        return round(area, 2)\n",
       "    else:\n",
       "        return -1\n",
       "\n",
       "def test_triangle_area():\n",
       "   assert triangle_area(3, 4, 5) == 6.00\n",
       "   assert triangle_area(1, 2, 10) == -1\n",
       "   assert triangle_area(3, 4, 7) == -1\n",
       "   assert triangle_area(1, 1, 2) == -1\n",
       "   assert triangle_area(12, 5, 13) == 17.71\n",
       "```\n",
       "</pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lm = mistral + run_tests_and_fix(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.0\n",
      "-1\n"
     ]
    }
   ],
   "source": [
    "program = lm['program']\n",
    "exec(program)\n",
    "print(triangle_area(3, 4, 5))\n",
    "print(triangle_area(1, 2, 10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this particular case, having more rounds allows the model to fix its program on the unit tests. Does it also result in a program that passes the evaluation tests?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_program(program, idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yes. Indeed, this simple prompt modification raises accuracy from 54.9% to 57.3% for this model (we've seen bigger gains with larger models)\n",
    "\n",
    "Anyway, the point of this notebook is just to illustrate how easy it is to guide generation depending on what previous generations are (e.g. the test results depend on the current version of the code.)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "guidance_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
